Update python/mlx/nn/layers/normalization.py

Co-authored-by: Robert McCraith <mccraithrobert@gmail.com>
This commit is contained in:
__mo_san__ 2023-12-23 15:20:08 +01:00 committed by m0saan
parent 019a85511c
commit c68a472b83

View File

@ -234,8 +234,8 @@ class BatchNorm(Module):
self.bias = mx.zeros((num_features,))
if self.track_running_stats:
self.running_mean = mx.zeros((num_features,))
self.running_var = mx.ones((num_features,))
self._running_mean = mx.zeros((num_features,))
self._running_var = mx.ones((num_features,))
def _extra_repr(self):
return f"{self.num_features}, eps={self.eps}, momentum={self.momentum}, affine={'weight' in self}, track_running_stats={self.track_running_stats}"