mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
improve batch norm code ^^
This commit is contained in:
parent
b444a6a693
commit
019a85511c
@ -227,7 +227,7 @@ class BatchNorm(Module):
|
|||||||
self.momentum = momentum
|
self.momentum = momentum
|
||||||
self.affine = affine
|
self.affine = affine
|
||||||
self.track_running_stats = track_running_stats
|
self.track_running_stats = track_running_stats
|
||||||
self.dims_expanded = False
|
self._dims_expanded = False
|
||||||
|
|
||||||
if self.affine:
|
if self.affine:
|
||||||
self.weight = mx.ones((num_features,))
|
self.weight = mx.ones((num_features,))
|
||||||
@ -268,7 +268,7 @@ class BatchNorm(Module):
|
|||||||
self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
|
self.running_mean = mx.expand_dims(self.running_mean, self.reduction_axes)
|
||||||
self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
|
self.running_var = mx.expand_dims(self.running_var, self.reduction_axes)
|
||||||
|
|
||||||
self.dims_expanded = True
|
self._dims_expanded = True
|
||||||
|
|
||||||
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||||
"""
|
"""
|
||||||
@ -304,7 +304,7 @@ class BatchNorm(Module):
|
|||||||
mx.array: Output tensor.
|
mx.array: Output tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not self.dims_expanded:
|
if not self._dims_expanded:
|
||||||
self._check_and_expand_dims(x)
|
self._check_and_expand_dims(x)
|
||||||
|
|
||||||
if self.training or not self.track_running_stats:
|
if self.training or not self.track_running_stats:
|
||||||
|
Loading…
Reference in New Issue
Block a user