mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 10:26:56 +08:00
Update batchnorm to have the running stats in parameters (#305)
This commit is contained in:

committed by
GitHub

parent
040c3bafab
commit
d29770eeaa
@@ -243,8 +243,15 @@ class BatchNorm(Module):
|
||||
self.bias = mx.zeros((num_features,))
|
||||
|
||||
if self.track_running_stats:
|
||||
self._running_mean = mx.zeros((num_features,))
|
||||
self._running_var = mx.ones((num_features,))
|
||||
self.running_mean = mx.zeros((num_features,))
|
||||
self.running_var = mx.ones((num_features,))
|
||||
self.freeze(keys=["running_mean", "running_var"], recurse=False)
|
||||
|
||||
def unfreeze(self, *args, **kwargs):
|
||||
"""Wrap unfreeze to make sure that running_mean and var are always
|
||||
frozen parameters."""
|
||||
super().unfreeze(*args, **kwargs)
|
||||
self.freeze(keys=["running_mean", "running_var"], recurse=False)
|
||||
|
||||
def _extra_repr(self):
|
||||
return (
|
||||
@@ -255,46 +262,47 @@ class BatchNorm(Module):
|
||||
|
||||
def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""
|
||||
Calculate the mean and variance of the input tensor.
|
||||
Calculate the mean and variance of the input tensor across the batch
|
||||
and spatial dimensions.
|
||||
|
||||
Args:
|
||||
x (mx.array): Input tensor.
|
||||
x (array): Input tensor.
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing mean and variance.
|
||||
"""
|
||||
reduction_axes = tuple(range(0, x.ndim - 1))
|
||||
means = mx.mean(x, axis=reduction_axes, keepdims=True)
|
||||
|
||||
mean = mx.mean(x, axis=reduction_axes, keepdims=True)
|
||||
var = mx.var(x, axis=reduction_axes, keepdims=True)
|
||||
|
||||
if self.track_running_stats and self.training:
|
||||
self._running_mean = (
|
||||
1 - self.momentum
|
||||
) * self._running_mean + self.momentum * means
|
||||
self._running_var = (
|
||||
1 - self.momentum
|
||||
) * self._running_var + self.momentum * var
|
||||
return means, var
|
||||
return mean, var
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass of BatchNorm.
|
||||
|
||||
Args:
|
||||
x (mx.array): Input tensor.
|
||||
x (array): Input tensor.
|
||||
|
||||
Returns:
|
||||
mx.array: Output tensor.
|
||||
array: Normalized output tensor.
|
||||
"""
|
||||
|
||||
if x.ndim < 2 or x.ndim > 4:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
|
||||
)
|
||||
|
||||
if self.training or not self.track_running_stats:
|
||||
means, var = self._calc_stats(x)
|
||||
else:
|
||||
means, var = self._running_mean, self._running_var
|
||||
x = (x - means) * mx.rsqrt(var + self.eps)
|
||||
# Calculate the mean and variance used to normalize the input x. If we
|
||||
# are in training mode update the running stats if needed.
|
||||
mean, var = self._calc_stats(x)
|
||||
if self.training and self.track_running_stats:
|
||||
mu = self.momentum
|
||||
self.running_mean = (1 - mu) * self.running_mean + mu * mean
|
||||
self.running_var = (1 - mu) * self.running_var + mu * var
|
||||
elif self.track_running_stats:
|
||||
mean = self.running_mean
|
||||
var = self.running_var
|
||||
|
||||
x = (x - mean) * mx.rsqrt(var + self.eps)
|
||||
return (self.weight * x + self.bias) if "weight" in self else x
|
||||
|
Reference in New Issue
Block a user