mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
improve batch norm code ^^
This commit is contained in:
parent
b444a6a693
commit
019a85511c
@ -227,7 +227,7 @@ class BatchNorm(Module):
|
||||
self.momentum = momentum
|
||||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
self.dims_expanded = False
|
||||
self._dims_expanded = False
|
||||
|
||||
if self.affine:
|
||||
self.weight = mx.ones((num_features,))
|
||||
@ -268,7 +268,7 @@ class BatchNorm(Module):
|
||||
self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
|
||||
self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
|
||||
|
||||
self.dims_expanded = True
|
||||
self._dims_expanded = True
|
||||
|
||||
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
@ -304,7 +304,7 @@ class BatchNorm(Module):
|
||||
mx.array: Output tensor.
|
||||
"""
|
||||
|
||||
if not self.dims_expanded:
|
||||
if not self._dims_expanded:
|
||||
self._check_and_expand_dims(x)
|
||||
|
||||
if self.training or not self.track_running_stats:
|
||||
|
Loading…
Reference in New Issue
Block a user